{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# https://huggingface.co/yuanzhoulvpi/gpt2_chinese\n",
    "# https://huggingface.co/google/vit-base-patch16-224\n",
    "# https://huggingface.co/nlpconnect/vit-gpt2-image-captioning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import (VisionEncoderDecoderModel,\n",
    "                          ViTModel,GPT2LMHeadModel,\n",
    "                          AutoTokenizer,ViTImageProcessor,\n",
    "                          Trainer,TrainingArguments)\n",
    "from typing import List, Any \n",
    "import torch\n",
    "from torch import Tensor\n",
    "from PIL import Image\n",
    "from datasets import load_dataset,Dataset\n",
    "\n",
    "from tqdm import tqdm \n",
    "import numpy as np \n",
    "import pandas as pd "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at google/vit-base-patch16-224 were not used when initializing ViTModel: ['classifier.weight', 'classifier.bias']\n",
      "- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at yuanzhoulvpi/gpt2_chinese and are newly initialized: ['transformer.h.6.crossattention.q_attn.weight', 'transformer.h.0.crossattention.c_proj.bias', 'transformer.h.1.crossattention.c_proj.bias', 'transformer.h.1.crossattention.c_attn.weight', 'transformer.h.1.crossattention.masked_bias', 'transformer.h.9.crossattention.q_attn.weight', 'transformer.h.7.crossattention.c_proj.weight', 'transformer.h.8.crossattention.q_attn.weight', 'transformer.h.0.ln_cross_attn.weight', 'transformer.h.5.crossattention.c_proj.bias', 'transformer.h.8.crossattention.c_proj.weight', 'transformer.h.10.crossattention.c_attn.weight', 'transformer.h.6.ln_cross_attn.weight', 'transformer.h.5.crossattention.c_proj.weight', 'transformer.h.0.crossattention.c_proj.weight', 'transformer.h.10.crossattention.masked_bias', 'transformer.h.8.crossattention.c_proj.bias', 'transformer.h.11.crossattention.bias', 'transformer.h.8.crossattention.masked_bias', 'transformer.h.7.crossattention.c_attn.weight', 'transformer.h.1.crossattention.c_proj.weight', 'transformer.h.1.ln_cross_attn.weight', 'transformer.h.5.ln_cross_attn.weight', 'transformer.h.6.crossattention.c_attn.weight', 'transformer.h.1.crossattention.bias', 'transformer.h.3.crossattention.q_attn.weight', 'transformer.h.6.crossattention.masked_bias', 'transformer.h.5.crossattention.q_attn.weight', 'transformer.h.5.crossattention.bias', 'transformer.h.10.crossattention.bias', 'transformer.h.2.crossattention.masked_bias', 'transformer.h.11.ln_cross_attn.weight', 'transformer.h.9.ln_cross_attn.weight', 'transformer.h.4.crossattention.c_attn.weight', 'transformer.h.2.crossattention.c_attn.weight', 'transformer.h.10.ln_cross_attn.weight', 'transformer.h.0.crossattention.bias', 'transformer.h.6.crossattention.c_proj.bias', 'transformer.h.10.crossattention.c_proj.weight', 'transformer.h.4.crossattention.q_attn.weight', 'transformer.h.11.crossattention.masked_bias', 
'transformer.h.2.crossattention.q_attn.weight', 'transformer.h.7.ln_cross_attn.weight', 'transformer.h.4.ln_cross_attn.weight', 'transformer.h.9.crossattention.c_proj.weight', 'transformer.h.4.crossattention.c_proj.weight', 'transformer.h.11.crossattention.q_attn.weight', 'transformer.h.0.crossattention.masked_bias', 'transformer.h.11.crossattention.c_attn.weight', 'transformer.h.10.crossattention.q_attn.weight', 'transformer.h.3.ln_cross_attn.weight', 'transformer.h.0.crossattention.q_attn.weight', 'transformer.h.7.crossattention.q_attn.weight', 'transformer.h.9.crossattention.masked_bias', 'transformer.h.7.crossattention.c_proj.bias', 'transformer.h.8.ln_cross_attn.weight', 'transformer.h.2.crossattention.c_proj.bias', 'transformer.h.9.crossattention.bias', 'transformer.h.9.crossattention.c_attn.weight', 'transformer.h.2.crossattention.bias', 'transformer.h.3.crossattention.c_attn.weight', 'transformer.h.5.crossattention.masked_bias', 'transformer.h.2.ln_cross_attn.weight', 'transformer.h.3.crossattention.c_proj.weight', 'transformer.h.4.crossattention.masked_bias', 'transformer.h.4.crossattention.bias', 'transformer.h.3.crossattention.masked_bias', 'transformer.h.3.crossattention.c_proj.bias', 'transformer.h.6.crossattention.c_proj.weight', 'transformer.h.2.crossattention.c_proj.weight', 'transformer.h.3.crossattention.bias', 'transformer.h.11.crossattention.c_proj.bias', 'transformer.h.9.crossattention.c_proj.bias', 'transformer.h.10.crossattention.c_proj.bias', 'transformer.h.8.crossattention.bias', 'transformer.h.0.crossattention.c_attn.weight', 'transformer.h.7.crossattention.masked_bias', 'transformer.h.11.crossattention.c_proj.weight', 'transformer.h.4.crossattention.c_proj.bias', 'transformer.h.7.crossattention.bias', 'transformer.h.8.crossattention.c_attn.weight', 'transformer.h.6.crossattention.bias', 'transformer.h.5.crossattention.c_attn.weight', 'transformer.h.1.crossattention.q_attn.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Hub checkpoints for the image encoder and the text decoder.\n",
    "VIT_MODEL_NAME_OR_PATH = \"google/vit-base-patch16-224\"\n",
    "GPT_MODEL_NAME_OR_PATH = \"yuanzhoulvpi/gpt2_chinese\"\n",
    "\n",
    "# Encoder: ViT backbone; the checkpoint's classifier weights are not used.\n",
    "VIT_model = ViTModel.from_pretrained(VIT_MODEL_NAME_OR_PATH)\n",
    "\n",
    "# Decoder: GPT-2 with cross-attention enabled; those layers are absent from\n",
    "# the checkpoint, so the load warning reports them as newly initialized.\n",
    "GPT_model = GPT2LMHeadModel.from_pretrained(\n",
    "    GPT_MODEL_NAME_OR_PATH,\n",
    "    add_cross_attention=True,\n",
    ")\n",
    "\n",
    "# Echo the flag as the cell output to confirm it reached the decoder config.\n",
    "GPT_model.config.add_cross_attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Text side: tokenizer loaded from the same checkpoint as the GPT-2 decoder.\n",
    "tokenizer = AutoTokenizer.from_pretrained(GPT_MODEL_NAME_OR_PATH)\n",
    "\n",
    "# Image side: preprocessing config loaded from the same checkpoint as the ViT encoder.\n",
    "processor = ViTImageProcessor.from_pretrained(VIT_MODEL_NAME_OR_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 224, 224])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def process_image_2_pixel_value(x:str) -> Tensor:\n",
    "    \"\"\"Load the image file at `x` and return its preprocessed pixel tensor.\n",
    "\n",
    "    Args:\n",
    "        x: Path of the image file to read.\n",
    "\n",
    "    Returns:\n",
    "        The `pixel_values` entry produced by `processor`, with the leading\n",
    "        batch dimension squeezed away.\n",
    "    \"\"\"\n",
    "    # The context manager releases the file handle deterministically; this\n",
    "    # function is called once per dataset row.\n",
    "    with Image.open(x) as image:\n",
    "        # Normalise to three channels so grayscale, palette, RGBA or CMYK files\n",
    "        # reach the processor in the same layout as ordinary RGB images.\n",
    "        # convert() decodes the pixels and returns a copy that outlives the file.\n",
    "        rgb_image = image.convert(\"RGB\")\n",
    "    res = processor(images=rgb_image, return_tensors='pt')['pixel_values'].squeeze(0)\n",
    "    return res\n",
    "\n",
    "\n",
    "# Smoke test on one sample image: the cell output is the tensor shape.\n",
    "process_image_2_pixel_value(x = \"bigdata/image_data/test-9282.jpg\").shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "100"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def process_text_2_input_id(x:str, max_length:int = 100) -> List[int]:\n",
    "    \"\"\"Tokenize `x` into a fixed-length list of token ids.\n",
    "\n",
    "    Args:\n",
    "        x: Text to encode.\n",
    "        max_length: Length every result is truncated or padded to. Defaults to\n",
    "            100, the value that used to be hard-coded here.\n",
    "\n",
    "    Returns:\n",
    "        The `input_ids` list produced by `tokenizer`.\n",
    "    \"\"\"\n",
    "    encoded = tokenizer(text=x, max_length=max_length, truncation=True, padding=\"max_length\")\n",
    "    return encoded['input_ids']\n",
    "\n",
    "\n",
    "# Sanity check: with the default max_length a short string still yields 100 ids.\n",
    "len(process_text_2_input_id(x='hhh'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "21128"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the id the tokenizer pads with; it is copied onto the model config further down.\n",
    "tokenizer.pad_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# GPT_model.config.add_cross_attention = True\n",
    "# # GPT_model.crossattention = False\n",
    "# GPT_model.config.add_cross_attention\n",
    "\n",
    "# # config.add_cross_attention=True\n",
    "# # hasattr(GPT_model, \"crossattention\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combine the pretrained ViT encoder and GPT-2 decoder into one captioning model.\n",
    "new_encoder_decoder_model = VisionEncoderDecoderModel(encoder=VIT_model, decoder=GPT_model)\n",
    "\n",
    "# Copy the tokenizer's special-token ids onto the combined config.\n",
    "# NOTE(review): assumes tokenizer.bos_token_id is not None for this checkpoint - confirm.\n",
    "new_encoder_decoder_model.config.pad_token_id = tokenizer.pad_token_id\n",
    "new_encoder_decoder_model.config.decoder_start_token_id = tokenizer.bos_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 100])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape check: one tokenized caption as a long tensor with a leading batch axis.\n",
    "torch.tensor(\n",
    "    process_text_2_input_id(x='hhh'),\n",
    "    dtype=torch.long,\n",
    ").unsqueeze(0).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Set the flag on the combined model's top-level config as well,\n",
    "# then echo it back as the cell output.\n",
    "new_encoder_decoder_model.config.add_cross_attention = True\n",
    "new_encoder_decoder_model.config.add_cross_attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f0a653acd0a64180b77ac58641f44b68",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1308 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6235c64893894e40969f7d736cf61c06",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/2 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['image_path', 'text', 'labels'],\n",
       "        num_rows: 1307767\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['image_path', 'text', 'labels'],\n",
       "        num_rows: 1310\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the caption table and hold out 0.1% of the rows for evaluation.\n",
    "# The fixed seed keeps the split identical across kernel restarts, so rows\n",
    "# evaluated on in one run cannot land in the training split of the next.\n",
    "dataset = Dataset.from_pandas(df=pd.read_csv(\"bigdata/clean_train_test/train.csv\"))\n",
    "dataset = dataset.train_test_split(test_size=0.001, seed=42)\n",
    "\n",
    "\n",
    "def tokenizer_text(examples) :\n",
    "    \"\"\"Batched `map` callback: add a `labels` column of token ids built from `text`.\n",
    "\n",
    "    Args:\n",
    "        examples: Batch mapping with a `text` entry holding the captions.\n",
    "\n",
    "    Returns:\n",
    "        The same mapping, with `labels` set to one id list per caption.\n",
    "    \"\"\"\n",
    "    label_ids = []\n",
    "    for caption in examples['text']:\n",
    "        label_ids.append(process_text_2_input_id(caption))\n",
    "    examples['labels'] = label_ids\n",
    "    return examples\n",
    "\n",
    "def transform_images(examples):\n",
    "    \"\"\"Dataset transform: load every path in `image_path` into `pixel_values`.\n",
    "\n",
    "    Args:\n",
    "        examples: Batch mapping with an `image_path` entry holding file paths.\n",
    "\n",
    "    Returns:\n",
    "        The same mapping, with `pixel_values` set to one tensor per path.\n",
    "    \"\"\"\n",
    "    examples['pixel_values'] = list(map(process_image_2_pixel_value, examples['image_path']))\n",
    "    return examples\n",
    "\n",
    "# Tokenize the captions once, up front, and store them as the `labels` column.\n",
    "dataset = dataset.map(function=tokenizer_text, batched=True)\n",
    "\n",
    "# Images are not materialised with `map`; the transform is registered instead\n",
    "# so pixel tensors are built when rows are accessed.\n",
    "dataset.set_transform(transform=transform_images)\n",
    "\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using cuda_amp half precision backend\n",
      "c:\\Users\\yuanz\\anaconda3\\envs\\mynet2\\lib\\site-packages\\transformers\\optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  warnings.warn(\n",
      "***** Running training *****\n",
      "  Num examples = 1307767\n",
      "  Num Epochs = 1\n",
      "  Instantaneous batch size per device = 48\n",
      "  Total train batch size (w. parallel, distributed & accumulation) = 384\n",
      "  Gradient Accumulation steps = 8\n",
      "  Total optimization steps = 3405\n",
      "  Number of trainable parameters = 216825600\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "02d87be4e8ec4c8baaafdfcd51b9019f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3405 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "***** Running Evaluation *****\n",
      "  Num examples = 1310\n",
      "  Batch size = 32\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 1.6245, 'learning_rate': 0.0004831670371886689, 'epoch': 0.12}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "949ad4c788a343c6ac653e71084fc3f4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/41 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving model checkpoint to vit-gpt2-image-chinese-captioning\\checkpoint-400\n",
      "Configuration saved in vit-gpt2-image-chinese-captioning\\checkpoint-400\\config.json\n",
      "Configuration saved in vit-gpt2-image-chinese-captioning\\checkpoint-400\\generation_config.json\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 1.1368376016616821, 'eval_runtime': 66.9577, 'eval_samples_per_second': 19.565, 'eval_steps_per_second': 0.612, 'epoch': 0.12}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Model weights saved in vit-gpt2-image-chinese-captioning\\checkpoint-400\\pytorch_model.bin\n",
      "c:\\Users\\yuanz\\anaconda3\\envs\\mynet2\\lib\\site-packages\\PIL\\Image.py:3074: DecompressionBombWarning: Image size (98988480 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
      "  warnings.warn(\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 1310\n",
      "  Batch size = 32\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 1.0244, 'learning_rate': 0.0004349349378507369, 'epoch': 0.23}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "84e037bcbe6e484cae92240cb00f094f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/41 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving model checkpoint to vit-gpt2-image-chinese-captioning\\checkpoint-800\n",
      "Configuration saved in vit-gpt2-image-chinese-captioning\\checkpoint-800\\config.json\n",
      "Configuration saved in vit-gpt2-image-chinese-captioning\\checkpoint-800\\generation_config.json\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.934100866317749, 'eval_runtime': 37.9955, 'eval_samples_per_second': 34.478, 'eval_steps_per_second': 1.079, 'epoch': 0.23}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Model weights saved in vit-gpt2-image-chinese-captioning\\checkpoint-800\\pytorch_model.bin\n",
      "Deleting older checkpoint [vit-gpt2-image-chinese-captioning\\checkpoint-500] due to args.save_total_limit\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 1310\n",
      "  Batch size = 32\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.9268, 'learning_rate': 0.0003617988150619466, 'epoch': 0.35}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a0713c04c9334a76a2ed3119a704102b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/41 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving model checkpoint to vit-gpt2-image-chinese-captioning\\checkpoint-1200\n",
      "Configuration saved in vit-gpt2-image-chinese-captioning\\checkpoint-1200\\config.json\n",
      "Configuration saved in vit-gpt2-image-chinese-captioning\\checkpoint-1200\\generation_config.json\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.8836056590080261, 'eval_runtime': 40.0826, 'eval_samples_per_second': 32.683, 'eval_steps_per_second': 1.023, 'epoch': 0.35}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Model weights saved in vit-gpt2-image-chinese-captioning\\checkpoint-1200\\pytorch_model.bin\n",
      "Deleting older checkpoint [vit-gpt2-image-chinese-captioning\\checkpoint-1000] due to args.save_total_limit\n",
      "c:\\Users\\yuanz\\anaconda3\\envs\\mynet2\\lib\\site-packages\\PIL\\Image.py:3074: DecompressionBombWarning: Image size (96407014 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
      "  warnings.warn(\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 1310\n",
      "  Batch size = 32\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.8801, 'learning_rate': 0.0002736074499028474, 'epoch': 0.47}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "51b266bcdbeb4cd5896e324bc463bd18",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/41 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving model checkpoint to vit-gpt2-image-chinese-captioning\\checkpoint-1600\n",
      "Configuration saved in vit-gpt2-image-chinese-captioning\\checkpoint-1600\\config.json\n",
      "Configuration saved in vit-gpt2-image-chinese-captioning\\checkpoint-1600\\generation_config.json\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.8513150215148926, 'eval_runtime': 33.9441, 'eval_samples_per_second': 38.593, 'eval_steps_per_second': 1.208, 'epoch': 0.47}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Model weights saved in vit-gpt2-image-chinese-captioning\\checkpoint-1600\\pytorch_model.bin\n",
      "Deleting older checkpoint [vit-gpt2-image-chinese-captioning\\checkpoint-1500] due to args.save_total_limit\n",
      "c:\\Users\\yuanz\\anaconda3\\envs\\mynet2\\lib\\site-packages\\PIL\\Image.py:3074: DecompressionBombWarning: Image size (101756928 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
      "  warnings.warn(\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 1310\n",
      "  Batch size = 32\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.8534, 'learning_rate': 0.00018223701813346817, 'epoch': 0.59}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5ae4a462ea0a4f1fbd2e6cb934e93c2d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/41 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving model checkpoint to vit-gpt2-image-chinese-captioning\\checkpoint-2000\n",
      "Configuration saved in vit-gpt2-image-chinese-captioning\\checkpoint-2000\\config.json\n",
      "Configuration saved in vit-gpt2-image-chinese-captioning\\checkpoint-2000\\generation_config.json\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.8301102519035339, 'eval_runtime': 31.7757, 'eval_samples_per_second': 41.226, 'eval_steps_per_second': 1.29, 'epoch': 0.59}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Model weights saved in vit-gpt2-image-chinese-captioning\\checkpoint-2000\\pytorch_model.bin\n",
      "Deleting older checkpoint [vit-gpt2-image-chinese-captioning\\checkpoint-400] due to args.save_total_limit\n",
      "c:\\Users\\yuanz\\anaconda3\\envs\\mynet2\\lib\\site-packages\\PIL\\Image.py:3074: DecompressionBombWarning: Image size (107628420 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
      "  warnings.warn(\n",
      "c:\\Users\\yuanz\\anaconda3\\envs\\mynet2\\lib\\site-packages\\PIL\\Image.py:3074: DecompressionBombWarning: Image size (101756928 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
      "  warnings.warn(\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 1310\n",
      "  Batch size = 32\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.8339, 'learning_rate': 9.999180039404274e-05, 'epoch': 0.7}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2ef1a7ac4a6d47c3ab2d278e2b8a70b0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/41 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving model checkpoint to vit-gpt2-image-chinese-captioning\\checkpoint-2400\n",
      "Configuration saved in vit-gpt2-image-chinese-captioning\\checkpoint-2400\\config.json\n",
      "Configuration saved in vit-gpt2-image-chinese-captioning\\checkpoint-2400\\generation_config.json\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.8126574158668518, 'eval_runtime': 33.2058, 'eval_samples_per_second': 39.451, 'eval_steps_per_second': 1.235, 'epoch': 0.7}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Model weights saved in vit-gpt2-image-chinese-captioning\\checkpoint-2400\\pytorch_model.bin\n",
      "Deleting older checkpoint [vit-gpt2-image-chinese-captioning\\checkpoint-800] due to args.save_total_limit\n",
      "c:\\Users\\yuanz\\anaconda3\\envs\\mynet2\\lib\\site-packages\\PIL\\Image.py:3074: DecompressionBombWarning: Image size (138921744 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
      "  warnings.warn(\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 1310\n",
      "  Batch size = 32\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.8214, 'learning_rate': 3.794724221751192e-05, 'epoch': 0.82}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "810766f0fecc4639bb7ad8006ffa6812",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/41 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving model checkpoint to vit-gpt2-image-chinese-captioning\\checkpoint-2800\n",
      "Configuration saved in vit-gpt2-image-chinese-captioning\\checkpoint-2800\\config.json\n",
      "Configuration saved in vit-gpt2-image-chinese-captioning\\checkpoint-2800\\generation_config.json\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.8048203587532043, 'eval_runtime': 33.377, 'eval_samples_per_second': 39.249, 'eval_steps_per_second': 1.228, 'epoch': 0.82}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Model weights saved in vit-gpt2-image-chinese-captioning\\checkpoint-2800\\pytorch_model.bin\n",
      "Deleting older checkpoint [vit-gpt2-image-chinese-captioning\\checkpoint-1200] due to args.save_total_limit\n",
      "c:\\Users\\yuanz\\anaconda3\\envs\\mynet2\\lib\\site-packages\\PIL\\TiffImagePlugin.py:850: UserWarning: Truncated File Read\n",
      "  warnings.warn(str(msg))\n",
      "c:\\Users\\yuanz\\anaconda3\\envs\\mynet2\\lib\\site-packages\\PIL\\Image.py:3074: DecompressionBombWarning: Image size (96458040 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
      "  warnings.warn(\n",
      "c:\\Users\\yuanz\\anaconda3\\envs\\mynet2\\lib\\site-packages\\PIL\\Image.py:3074: DecompressionBombWarning: Image size (127792080 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
      "  warnings.warn(\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 1310\n",
      "  Batch size = 32\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.8142, 'learning_rate': 4.4584935273235815e-06, 'epoch': 0.94}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c79e75d5dd754fcb9bb5994acfffa4bd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/41 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving model checkpoint to vit-gpt2-image-chinese-captioning\\checkpoint-3200\n",
      "Configuration saved in vit-gpt2-image-chinese-captioning\\checkpoint-3200\\config.json\n",
      "Configuration saved in vit-gpt2-image-chinese-captioning\\checkpoint-3200\\generation_config.json\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.8004471659660339, 'eval_runtime': 27.0937, 'eval_samples_per_second': 48.351, 'eval_steps_per_second': 1.513, 'epoch': 0.94}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Model weights saved in vit-gpt2-image-chinese-captioning\\checkpoint-3200\\pytorch_model.bin\n",
      "Deleting older checkpoint [vit-gpt2-image-chinese-captioning\\checkpoint-1600] due to args.save_total_limit\n"
     ]
    }
   ],
   "source": [
    "def collate_fn(examples):\n",
    "    pixel_values = torch.stack([i['pixel_values'] for i in examples])\n",
    "    labels = torch.tensor([example[\"labels\"] for example in examples], dtype=torch.long)\n",
    "    return {\n",
    "        \"pixel_values\": pixel_values,\n",
    "        \"labels\": labels\n",
    "    }\n",
    "\n",
    "\n",
    "train_argument = TrainingArguments(\n",
    "    output_dir=\"vit-gpt2-image-chinese-captioning\",\n",
    "    # Batching: 48 per device with 8 accumulation steps.\n",
    "    per_device_train_batch_size=48,\n",
    "    per_device_eval_batch_size=32,\n",
    "    gradient_accumulation_steps=8,\n",
    "    # Optimisation: one epoch, cosine schedule, mixed precision.\n",
    "    num_train_epochs=1,\n",
    "    learning_rate=5e-4,\n",
    "    lr_scheduler_type=\"cosine\",\n",
    "    weight_decay=0.1,\n",
    "    fp16=True,\n",
    "    # Evaluate, log and checkpoint on the same 400-step cadence.\n",
    "    evaluation_strategy=\"steps\",\n",
    "    eval_steps=400,\n",
    "    logging_steps=400,\n",
    "    save_steps=400,\n",
    "    save_total_limit=4,\n",
    "    # Presumably needed so `image_path` survives for the image transform - confirm.\n",
    "    remove_unused_columns=False,\n",
    ")\n",
    "\n",
    "\n",
    "\n",
    "# Wire the model, arguments, both splits and the custom collator together.\n",
    "trainer = Trainer(\n",
    "    args=train_argument,\n",
    "    model=new_encoder_decoder_model,\n",
    "    data_collator=collate_fn,\n",
    "    train_dataset=dataset['train'],\n",
    "    eval_dataset=dataset['test'],\n",
    ")\n",
    "\n",
    "# Start training; checkpoints are written under the configured output_dir.\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): stray whitespace-only string literal; looks like a leftover scratch cell.\n",
    "'   '"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mynet2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "110bc624a448454d574a0cd6cc76359fd86f75739e493913b3d71c2e04f2ffdb"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
